# functions.py

import torch
import numpy as np
from scipy.integrate import quad
import torch
import time

# Define functions c(u) and a(u)
#def c(u):
#    return -u**2 / 2
def a(u):
    """Coefficient a(u) used in the exp(c(u)) / a(u) integrands of this module.

    It is the constant 0.5 for every input; ``u`` is accepted only so the
    signature mirrors ``c`` and is never read.
    """
    half = 0.5
    return half
    

def c(u_values, M=3, N=3, num_points=50):
    """Numerically evaluate c(u) = integral from 0 to u of b(t) / a(t) dt.

    Here b(t) = -t and a(t) = 1/2, so the integrand is -2t and the exact value
    is c(u) = -u**2. The integral is approximated by a rectangle sum over the
    nodes of ``torch.linspace(-M, N, num_points)``:

    * u >= 0: sum of integrand * dt over nodes t with 0 < t <= u;
    * u <  0: minus the sum over nodes t with u < t <= 0 (oriented integral).

    Nodes only exist on [-M, N], so arguments outside that range are
    effectively clipped to it.

    Args:
        u_values: tensor of evaluation points, of any shape (0-d and empty
            inputs are supported).
        M, N: the integration grid spans [-M, N].
        num_points: number of grid nodes.

    Returns:
        Tensor with the same shape as ``u_values``.
    """
    device = u_values.device
    original_shape = u_values.shape
    # One evaluation point per row: [numel, 1]. Works for any rank, including 0-d.
    u_col = u_values.reshape(-1, 1)

    t_values = torch.linspace(-M, N, num_points, device=device).unsqueeze(0)  # [1, num_points]
    b = -t_values          # b(t) = -t
    a_coef = 1 / 2         # a(t) = 1/2
    integrand = b / a_coef  # [1, num_points]
    # Spacing of linspace nodes is (M + N) / (num_points - 1), not / num_points.
    dt = (M + N) / max(num_points - 1, 1)

    mask_positive = (t_values > 0) & (t_values <= u_col)   # [numel, num_points]
    mask_negative = (t_values <= 0) & (t_values > u_col)   # [numel, num_points]

    result_positive = torch.sum(integrand * mask_positive, dim=-1) * dt
    # The negative side integrates "backwards" from 0 to u, hence the sign flip.
    result_negative = -torch.sum(integrand * mask_negative, dim=-1) * dt

    result = torch.where(u_col[:, 0] >= 0, result_positive, result_negative)
    return result.reshape(original_shape)



def f1(t, x, y, M, N, num_points=100):  # 100->200
    """Evaluate the piecewise function f1(t; x, y) element-wise.

    With s(u) = exp(c(u)) / a(u) and I[p, q] the rectangle sum of s(u) * du
    over the nodes u of ``linspace(-M, N, num_points)`` with p <= u <= q:

        t <= x      : sqrt(I[y, N] / I[-M, x]) * sqrt(I[-M, t])
        x < t < y   : sqrt(I[y, N])
        y <= t      : sqrt(I[t, N])

    Elements matching none of the three conditions (e.g. NaN) stay 0.
    ``t``, ``x`` and ``y`` are presumably the same shape — the boolean-mask
    assignments below index all three integrals with masks built from them.
    """
    dev = t.device
    grid = torch.linspace(-M, N, num_points, device=dev)
    du = grid[1] - grid[0]
    weights = torch.exp(c(grid)) / a(grid)

    def _rect_integral(lo, hi):
        # Compare each bound against every grid node along a new trailing axis.
        inside = (grid >= lo.unsqueeze(-1)) & (grid <= hi.unsqueeze(-1))
        return torch.sum(weights * inside, dim=-1) * du

    i_y_to_N = _rect_integral(y, torch.tensor(N, device=dev))
    i_M_to_x = _rect_integral(torch.tensor(-M, device=dev), x)
    i_M_to_t = _rect_integral(torch.tensor(-M, device=dev), t)
    i_t_to_N = _rect_integral(t, torch.tensor(N, device=dev))

    below_x = t <= x
    between = (x < t) & (t < y)
    above_y = y <= t

    out = torch.zeros_like(t)
    out[below_x] = (torch.sqrt(i_y_to_N[below_x] / i_M_to_x[below_x])
                    * torch.sqrt(i_M_to_t[below_x]))
    out[between] = torch.sqrt(i_y_to_N[between])
    out[above_y] = torch.sqrt(i_t_to_N[above_y])

    # Drop the large intermediates and release cached CUDA allocator blocks.
    del i_y_to_N, i_M_to_x, i_M_to_t, i_t_to_N
    torch.cuda.empty_cache()

    return out

def f2_minus(tensor, M, N, num_points=100, chunk_size=10, u_chunk_size=50):#numpoints100->200
    """Double rectangle-rule integral over the 'minus' region u <= t <= theta.

    For each (x, y, theta) cell of ``tensor`` this accumulates

        sum_{u, t} exp(c(u)) / a(u) * exp(-c(t)) * f1(t; x, y) * du * dt

    over the nodes of a ``num_points`` grid on [-M, max(theta)] (the same
    grid for u and t), restricted to u <= t <= theta. The grid is processed
    in blocks of at most ``u_chunk_size`` x ``u_chunk_size`` nodes. The first
    axis of ``tensor`` is processed ``chunk_size`` rows at a time to bound
    memory.

    Args:
        tensor: (x, y, theta) in the last axis. The code reads dims 0-2 of
            each component, so presumably shape (n0, n0, n1, 3) — TODO confirm.
        M, N: bounds forwarded to ``f1``. M is also the lower grid limit.
        num_points: nodes per grid axis (also forwarded to ``f1``).
        chunk_size: rows of the first axis per outer iteration.
        u_chunk_size: block edge length. ``num_points`` need not be a
            multiple of it.

    Returns:
        Tensor of shape ``tensor.shape[:3]``.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    u_values = torch.linspace(-M, theta_tensor.max().item(), num_points, device=device)
    t_values = torch.linspace(-M, theta_tensor.max().item(), num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]
    results = []

    for i in range(0, x_tensor.shape[0], chunk_size):
        x_chunk = x_tensor[i:i + chunk_size]
        y_chunk = y_tensor[i:i + chunk_size]
        theta_chunk = theta_tensor[i:i + chunk_size]
        lead_shape = (x_chunk.shape[0], y_chunk.shape[1], theta_chunk.shape[2])
        theta_bound = theta_chunk.unsqueeze(-1).unsqueeze(-1)

        chunk_results = []
        for u_start in range(0, num_points, u_chunk_size):
            for t_start in range(0, num_points, u_chunk_size):
                u_chunk = u_values[u_start:u_start + u_chunk_size]
                t_chunk = t_values[t_start:t_start + u_chunk_size]
                # Size the block from the actual slices: the trailing block is
                # shorter when num_points is not a multiple of u_chunk_size.
                block_shape = (*lead_shape, u_chunk.shape[0], t_chunk.shape[0])

                inner1 = torch.exp(c(u_chunk)) / a(u_chunk)
                inner2 = torch.exp(-c(t_chunk))
                inner1_grid, inner2_grid = torch.meshgrid(inner1, inner2, indexing='ij')
                inner1_grid = inner1_grid.expand(block_shape)
                inner2_grid = inner2_grid.expand(block_shape)

                u_grid, t_grid = torch.meshgrid(u_chunk, t_chunk, indexing='ij')
                u_grid = u_grid.expand(block_shape)
                t_grid = t_grid.expand(block_shape)
                # Keep only u <= t <= theta.
                inner_integral_mask = (t_grid >= u_grid) & (t_grid <= theta_bound)
                integrand_inner_masked = inner1_grid * inner2_grid * inner_integral_mask

                # f1 needs x and y materialised at the full block shape because
                # it uses boolean-mask assignment.
                f1_values = f1(
                    t_grid,
                    x_chunk.unsqueeze(-1).unsqueeze(-1).expand(block_shape),
                    y_chunk.unsqueeze(-1).unsqueeze(-1).expand(block_shape),
                    M, N, num_points,
                )

                inner_sum = torch.sum(integrand_inner_masked * f1_values, dim=(-2, -1)) * delta_u * delta_t
                chunk_results.append(inner_sum)

                # Clear GPU memory for the block
                del u_chunk, t_chunk, u_grid, t_grid, inner_integral_mask, integrand_inner_masked, f1_values, inner_sum
                torch.cuda.empty_cache()

        # Sum the contributions of all u/t blocks for this row chunk
        results.append(torch.sum(torch.stack(chunk_results), dim=0))

        del x_chunk, y_chunk, theta_chunk, chunk_results
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0)



def f2_plus(tensor, M, N, num_points=100, chunk_size=10, u_chunk_size=10):
    """Double rectangle-rule integral over the 'plus' region theta <= t <= u.

    For each (x, y, theta) cell of ``tensor`` this accumulates

        sum_{u, t} exp(c(u)) / a(u) * exp(-c(t)) * f1(t; x, y) * du * dt

    over the nodes of a ``num_points`` grid on [min(theta), N] (the same grid
    for u and t), restricted to theta <= t <= u. The grid is processed in
    blocks of at most ``u_chunk_size`` x ``u_chunk_size`` nodes. The first
    axis of ``tensor`` is processed ``chunk_size`` rows at a time.

    Args:
        tensor: (x, y, theta) in the last axis. The code reads dims 0-2 of
            each component, so presumably shape (n0, n0, n1, 3) — TODO confirm.
        M, N: bounds forwarded to ``f1``. N is also the upper grid limit.
        num_points: nodes per grid axis (also forwarded to ``f1``).
        chunk_size: rows of the first axis per outer iteration.
        u_chunk_size: block edge length. ``num_points`` need not be a
            multiple of it.

    Returns:
        Tensor of shape ``tensor.shape[:3]``, with a trailing size-1 axis
        squeezed by the final ``squeeze(-1)``.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    u_values = torch.linspace(theta_tensor.min().item(), N, num_points, device=device)
    t_values = torch.linspace(theta_tensor.min().item(), N, num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]
    results = []

    for i in range(0, x_tensor.shape[0], chunk_size):
        x_chunk = x_tensor[i:i + chunk_size]
        y_chunk = y_tensor[i:i + chunk_size]
        theta_chunk = theta_tensor[i:i + chunk_size]
        lead_shape = (x_chunk.shape[0], y_chunk.shape[1], theta_chunk.shape[2])
        theta_bound = theta_chunk.unsqueeze(-1).unsqueeze(-1)

        chunk_results = []
        for u_start in range(0, num_points, u_chunk_size):
            for t_start in range(0, num_points, u_chunk_size):
                u_chunk = u_values[u_start:u_start + u_chunk_size]
                t_chunk = t_values[t_start:t_start + u_chunk_size]
                # Size the block from the actual slices: the trailing block is
                # shorter when num_points is not a multiple of u_chunk_size.
                block_shape = (*lead_shape, u_chunk.shape[0], t_chunk.shape[0])

                inner1 = torch.exp(c(u_chunk)) / a(u_chunk)
                inner2 = torch.exp(-c(t_chunk))
                inner1_grid, inner2_grid = torch.meshgrid(inner1, inner2, indexing='ij')
                inner1_grid = inner1_grid.expand(block_shape)
                inner2_grid = inner2_grid.expand(block_shape)

                u_grid, t_grid = torch.meshgrid(u_chunk, t_chunk, indexing='ij')
                u_grid = u_grid.expand(block_shape)
                t_grid = t_grid.expand(block_shape)
                # Keep only theta <= t <= u.
                inner_integral_mask = (t_grid >= theta_bound) & (t_grid <= u_grid)
                integrand_inner_masked = inner1_grid * inner2_grid * inner_integral_mask

                # f1 needs x and y materialised at the full block shape because
                # it uses boolean-mask assignment.
                f1_values = f1(
                    t_grid,
                    x_chunk.unsqueeze(-1).unsqueeze(-1).expand(block_shape),
                    y_chunk.unsqueeze(-1).unsqueeze(-1).expand(block_shape),
                    M, N, num_points,
                )

                inner_sum = torch.sum(integrand_inner_masked * f1_values, dim=(-2, -1)) * delta_u * delta_t
                chunk_results.append(inner_sum)

                # Clear GPU memory for the block
                del u_chunk, t_chunk, u_grid, t_grid, inner_integral_mask, integrand_inner_masked, f1_values, inner_sum
                torch.cuda.empty_cache()

        # Sum the contributions of all u/t blocks for this row chunk
        results.append(torch.sum(torch.stack(chunk_results), dim=0))

        del x_chunk, y_chunk, theta_chunk, chunk_results
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0).squeeze(-1)




def f_u_minus(result_tensor, M, N, num_intervals=800, chunk_size=20):
    """Append f_u = exp(c(u)) / a(u) * integral from u to theta of exp(-c(t)) * f1(t; x, y) dt.

    ``result_tensor`` carries (x, y, theta, u) in its last axis. The integral
    is approximated with ``num_intervals`` equally spaced samples from u to
    theta (both ends included), each weighted by (theta - u) / num_intervals.
    Samples are processed ``chunk_size`` at a time, and the exp(-c(t)) * f1
    integrand is computed under ``torch.no_grad()``.

    Args:
        result_tensor: tensor whose last axis holds (x, y, theta, u).
        M, N: integration bounds forwarded to ``f1``.
        num_intervals: number of t samples.
        chunk_size: t samples per ``f1`` call.

    Returns:
        ``result_tensor`` with f_u concatenated as a new last-axis entry.
    """
    x = result_tensor[..., 0]
    y = result_tensor[..., 1]
    theta = result_tensor[..., 2]
    u = result_tensor[..., 3]
    device = result_tensor.device

    # Fractions 0..1 on a new leading axis, broadcastable over u's shape.
    frac = torch.linspace(0, 1, num_intervals).view(-1, *([1] * u.dim())).to(device)
    u_lead = u.unsqueeze(0)
    t_values = u_lead + (theta.unsqueeze(0) - u_lead) * frac  # (num_intervals, *u.shape)

    dt = (theta - u) / num_intervals
    exp_c_u = torch.exp(c(u))

    integrand_results = []
    with torch.no_grad():
        for t_chunk in torch.split(t_values, chunk_size, dim=0):
            exp_minus_c_t = torch.exp(-c(t_chunk))
            f1_chunk = f1(
                t_chunk,
                x=x.unsqueeze(0).expand(t_chunk.shape),
                y=y.unsqueeze(0).expand(t_chunk.shape),
                M=M,
                N=N,
            )
            integrand_results.append(exp_minus_c_t * f1_chunk)

    integrand = torch.cat(integrand_results, dim=0)  # (num_intervals, *u.shape)

    # Release the sample buffers before the reduction.
    del t_values, integrand_results
    torch.cuda.empty_cache()

    integral_value = torch.sum(integrand, dim=0) * dt  # same shape as u
    # No blanket .squeeze() here: it would also drop genuine size-1 grid axes
    # and misalign the broadcast against exp_c_u.
    f_u = (exp_c_u / a(u)) * integral_value

    return torch.cat((result_tensor, f_u.unsqueeze(-1)), dim=-1)
    
    
def f_u_plus(result_tensor, M, N, num_intervals=800, chunk_size=20):
    """Append f_u = exp(c(u)) / a(u) * integral from theta to u of exp(-c(t)) * f1(t; x, y) dt.

    ``result_tensor`` carries (x, y, theta, u) in its last axis. The integral
    is approximated with ``num_intervals`` equally spaced samples from theta
    to u (both ends included), each weighted by (u - theta) / num_intervals.
    Samples are processed ``chunk_size`` at a time, and the exp(-c(t)) * f1
    integrand is computed under ``torch.no_grad()``.

    Args:
        result_tensor: tensor whose last axis holds (x, y, theta, u).
        M, N: integration bounds forwarded to ``f1``.
        num_intervals: number of t samples.
        chunk_size: t samples per ``f1`` call.

    Returns:
        ``result_tensor`` with f_u concatenated as a new last-axis entry.
    """
    x = result_tensor[..., 0]
    y = result_tensor[..., 1]
    theta = result_tensor[..., 2]
    u = result_tensor[..., 3]
    device = result_tensor.device

    # Fractions 0..1 on a new leading axis, broadcastable over u's shape.
    frac = torch.linspace(0, 1, num_intervals).view(-1, *([1] * u.dim())).to(device)
    theta_lead = theta.unsqueeze(0)
    t_values = theta_lead + (u.unsqueeze(0) - theta_lead) * frac  # (num_intervals, *u.shape)

    dt = (u - theta) / num_intervals
    exp_c_u = torch.exp(c(u))

    integrand_results = []
    with torch.no_grad():
        for t_chunk in torch.split(t_values, chunk_size, dim=0):
            exp_minus_c_t = torch.exp(-c(t_chunk))
            f1_chunk = f1(
                t_chunk,
                x=x.unsqueeze(0).expand(t_chunk.shape),
                y=y.unsqueeze(0).expand(t_chunk.shape),
                M=M,
                N=N,
            )
            integrand_results.append(exp_minus_c_t * f1_chunk)

    integrand = torch.cat(integrand_results, dim=0)  # (num_intervals, *u.shape)

    # Release the sample buffers before the reduction.
    del t_values, integrand_results
    torch.cuda.empty_cache()

    integral_value = torch.sum(integrand, dim=0) * dt  # same shape as u
    # No blanket .squeeze() here: it would also drop genuine size-1 grid axes
    # and misalign the broadcast against exp_c_u.
    f_u = (exp_c_u / a(u)) * integral_value

    return torch.cat((result_tensor, f_u.unsqueeze(-1)), dim=-1)
    
#calculate f3------------------------------------------------------------------------------------------------------------------
def f3_minus(tensor, M, N, f2_z_values_expanded, num_points=200, chunk_size=10):
    """Double rectangle-rule sum over the 'minus' region u <= t <= theta, weighted by f2.

    For every cell this accumulates

        exp(c(u)) / a(u) * exp(-c(t)) * f2_chunk * du * dt

    over nodes u, t of a ``num_points`` grid on [-M, max(theta)], masked to
    u <= t <= theta. ``f2_chunk`` is ``f2_z_values_expanded`` sliced at the u
    indices of the current block.

    Args:
        tensor: (x, y, theta) in the last axis; presumably shape
            (n0, n0, n1, 3) — TODO confirm. Only theta is used numerically.
        M: lower grid limit. N is not used by this function.
        f2_z_values_expanded: indexed as [x_row, :, u_index, :]. Its last axis
            is multiplied against the grid's t axis, so it presumably has
            length ``num_points`` (or 1) — TODO confirm.
        num_points: grid nodes per axis.
        chunk_size: block size for the x rows and for the u/t grid rows.

    Returns:
        Tensor of shape (tensor.shape[0], tensor.shape[1], tensor.shape[2]).

    NOTE(review): ``t_grid`` comes from meshgrid(..., indexing='ij'), so every
    row of it equals ``t_values``. Slicing rows ``t_grid[t_chunk_start:...]``
    therefore does not partition t. Each inner t iteration re-adds the same
    full-t contribution, so the result appears scaled by the number of t
    chunks (num_points / chunk_size = 20 with the defaults). ``f3_plus`` has
    the same structure. ``calculate_knn_series`` only uses the argmin of
    |f3_plus - f3_minus|, so a common factor cancels there. If confirmed, fix
    both functions together.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    

    # u and t share one grid over [-M, max(theta)].
    u_values = torch.linspace(-M, theta_tensor.max().item(), num_points, device=device)
    t_values = torch.linspace(-M, theta_tensor.max().item(), num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]

    # 'ij' indexing: u_grid[i, j] = u_values[i], t_grid[i, j] = t_values[j].
    u_grid, t_grid = torch.meshgrid(u_values, t_values, indexing='ij')

    # Initialize result tensor
    result = torch.zeros(x_tensor.shape[0], x_tensor.shape[1], theta_tensor.shape[2], device=device)

    for x_chunk_start in range(0, x_tensor.shape[0], chunk_size):
        x_chunk_end = min(x_chunk_start + chunk_size, x_tensor.shape[0])
        x_chunk = x_tensor[x_chunk_start:x_chunk_end]  # not used below
        y_chunk = y_tensor[x_chunk_start:x_chunk_end]  # not used below
        
        # Theta rows for this chunk (the only component used in the mask)
        theta_chunk = theta_tensor[x_chunk_start:x_chunk_end]


        for u_chunk_start in range(0, num_points, chunk_size):
            u_chunk_end = min(u_chunk_start + chunk_size, num_points)
            # Rows u_chunk_start..u_chunk_end of the grid, all t columns: (cu, num_points)
            u_chunk = u_grid[u_chunk_start:u_chunk_end]
            
            for t_chunk_start in range(0, num_points, chunk_size):
                t_chunk_end = min(t_chunk_start + chunk_size, num_points)
                # NOTE(review): every row of t_grid equals t_values, so this is
                # the full t axis again (see docstring). `t_chunk >= u_chunk`
                # also requires both slices to have the same row count.
                t_chunk = t_grid[t_chunk_start:t_chunk_end]
                
                # f2 sampled at the u indices of this block
                f2_chunk = f2_z_values_expanded[x_chunk_start:x_chunk_end, :, u_chunk_start:u_chunk_end,:]
             
                
                # Compute integrals for the current chunk
                # Keep only u <= t <= theta; broadcasts to (xc, dim1, dim2, cu, num_points)
                inner_integral_mask = (t_chunk >= u_chunk) & (t_chunk <= theta_chunk.unsqueeze(-1).unsqueeze(-1))
                integrand_inner = (torch.exp(c(u_chunk)) / a(u_chunk)) * torch.exp(-c(t_chunk))
                
                integrand_inner_masked = integrand_inner * inner_integral_mask
                f2_chunk = f2_chunk.unsqueeze(2).expand(-1, -1, theta_tensor.shape[2], -1, -1) # broadcast f2 over the theta axis (tensor.shape[2])
                inner_sum = torch.sum(integrand_inner_masked * f2_chunk, dim=(-2, -1)) * delta_u * delta_t
                result[x_chunk_start:x_chunk_end, :, :] += inner_sum

    return result


    

def f3_plus(tensor, M, N, f2_z_values_expanded, num_points=200, chunk_size=10):
    """Double rectangle-rule sum over the 'plus' region theta <= t <= u, weighted by f2.

    For every cell this accumulates

        exp(c(u)) / a(u) * exp(-c(t)) * f2_chunk * du * dt

    over nodes u, t of a ``num_points`` grid on [min(theta), N], masked to
    theta <= t <= u. ``f2_chunk`` is ``f2_z_values_expanded`` sliced at the u
    indices of the current block.

    Args:
        tensor: (x, y, theta) in the last axis; presumably shape
            (n0, n0, n1, 3) — TODO confirm. Only theta is used numerically.
        N: upper grid limit. M is not used by this function.
        f2_z_values_expanded: indexed as [x_row, :, u_index, :]. Its last axis
            is multiplied against the grid's t axis, so it presumably has
            length ``num_points`` (or 1) — TODO confirm.
        num_points: grid nodes per axis.
        chunk_size: block size for the x rows and for the u/t grid rows.

    Returns:
        Tensor of shape (tensor.shape[0], tensor.shape[1], tensor.shape[2]).

    NOTE(review): ``t_grid`` comes from meshgrid(..., indexing='ij'), so every
    row of it equals ``t_values``. Slicing rows ``t_grid[t_chunk_start:...]``
    therefore does not partition t. Each inner t iteration re-adds the same
    full-t contribution, so the result appears scaled by the number of t
    chunks (num_points / chunk_size = 20 with the defaults). ``f3_minus`` has
    the same structure. ``calculate_knn_series`` only uses the argmin of
    |f3_plus - f3_minus|, so a common factor cancels there. If confirmed, fix
    both functions together.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device

    # u and t share one grid over [min(theta), N].
    u_values = torch.linspace(theta_tensor.min().item(), N, num_points, device=device)
    t_values = torch.linspace(theta_tensor.min().item(), N, num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]

    # 'ij' indexing: u_grid[i, j] = u_values[i], t_grid[i, j] = t_values[j].
    u_grid, t_grid = torch.meshgrid(u_values, t_values, indexing='ij')

    # Initialize result tensor
    result = torch.zeros(x_tensor.shape[0], x_tensor.shape[1], theta_tensor.shape[2], device=device)

    for x_chunk_start in range(0, x_tensor.shape[0], chunk_size):
        x_chunk_end = min(x_chunk_start + chunk_size, x_tensor.shape[0])
        x_chunk = x_tensor[x_chunk_start:x_chunk_end]  # not used below
        y_chunk = y_tensor[x_chunk_start:x_chunk_end]  # not used below
        
        # Theta rows for this chunk (the only component used in the mask)
        theta_chunk = theta_tensor[x_chunk_start:x_chunk_end]

        for u_chunk_start in range(0, num_points, chunk_size):
            u_chunk_end = min(u_chunk_start + chunk_size, num_points)
            # Rows u_chunk_start..u_chunk_end of the grid, all t columns: (cu, num_points)
            u_chunk = u_grid[u_chunk_start:u_chunk_end]
            
            for t_chunk_start in range(0, num_points, chunk_size):
                t_chunk_end = min(t_chunk_start + chunk_size, num_points)
                # NOTE(review): every row of t_grid equals t_values, so this is
                # the full t axis again (see docstring). `t_chunk <= u_chunk`
                # also requires both slices to have the same row count.
                t_chunk = t_grid[t_chunk_start:t_chunk_end]
                
                # f2 sampled at the u indices of this block
                f2_chunk = f2_z_values_expanded[x_chunk_start:x_chunk_end, :, u_chunk_start:u_chunk_end,:]
             
                # Compute integrals for the current chunk
                # Keep only theta <= t <= u; broadcasts to (xc, dim1, dim2, cu, num_points)
                inner_integral_mask = (t_chunk >= theta_chunk.unsqueeze(-1).unsqueeze(-1)) & (t_chunk <= u_chunk)
                integrand_inner = (torch.exp(c(u_chunk)) / a(u_chunk)) * torch.exp(-c(t_chunk))
                
                integrand_inner_masked = integrand_inner * inner_integral_mask
                f2_chunk = f2_chunk.unsqueeze(2).expand(-1, -1, theta_tensor.shape[2], -1, -1)  # broadcast f2 over the theta axis (tensor.shape[2])
                
                inner_sum = torch.sum(integrand_inner_masked * f2_chunk, dim=(-2, -1)) * delta_u * delta_t
                result[x_chunk_start:x_chunk_end, :, :] += inner_sum

    return result
    



def f3_u_minus(result_tensor, M, N, f2_z_values_expanded_1, num_intervals=200, chunk_size=5):
    """Append f_u = exp(c(u)) / a(u) * integral from u to theta of exp(-c(t)) * f2 dt.

    ``result_tensor`` carries (x, y, theta, u) in its last axis. f2 is looked
    up once per cell, at the node of z = linspace(-M, N, num_intervals)
    nearest to the lower limit t = u (first index on ties), and held constant
    over the integral. The integral uses ``num_intervals`` equally spaced
    samples from u to theta (ends included), each weighted by
    (theta - u) / num_intervals. The first axis is processed ``chunk_size``
    rows at a time.

    Args:
        result_tensor: (x, y, theta, u) in the last axis; the lookup assumes
            three leading axes, presumably (n0, n0, P) — TODO confirm.
        M, N: bounds of the z lookup grid.
        f2_z_values_expanded_1: only index [0] is read. It is gathered as
            [i, j, z_index], so it must cover the first two axes of
            ``result_tensor`` and at least ``num_intervals`` z entries.
        num_intervals: number of t samples and of z lookup nodes.
        chunk_size: first-axis rows per iteration.

    Returns:
        ``result_tensor`` with f_u concatenated as a new last-axis entry.
    """
    theta = result_tensor[..., 2]
    u = result_tensor[..., 3]
    device = result_tensor.device

    z = torch.linspace(-M, N, num_intervals).to(device)

    # Fractions 0..1 on a new leading axis, broadcastable over u's shape.
    frac = torch.linspace(0, 1, num_intervals).view(-1, *([1] * u.dim())).to(device)
    u_lead = u.unsqueeze(0)
    t_values = u_lead + (theta.unsqueeze(0) - u_lead) * frac  # (num_intervals, *u.shape)

    # Vectorised nearest-node lookup (replaces a per-element Python loop).
    # argmin returns the first minimal index, matching the scalar version.
    t_start = t_values[0]
    nearest_idx = torch.argmin(torch.abs(t_start.unsqueeze(-1) - z), dim=-1)
    f2_slice = f2_z_values_expanded_1[0]
    tensorf2_0 = torch.gather(f2_slice, -1, nearest_idx).to(device=t_start.device, dtype=t_start.dtype)
    # Same f2 value for every t sample; expand is a view (no copy needed).
    tensorf2 = tensorf2_0.unsqueeze(0).expand(num_intervals, *tensorf2_0.shape)

    result = []
    for start in range(0, u.shape[0], chunk_size):
        end = min(start + chunk_size, u.shape[0])

        theta_chunk = theta[start:end]
        u_chunk = u[start:end]
        f2_chunk = tensorf2[:, start:end]
        t_chunk = t_values[:, start:end]

        dt = (theta_chunk - u_chunk) / num_intervals
        exp_c_u = torch.exp(c(u_chunk))
        exp_minus_c_t = torch.exp(-c(t_chunk))

        integral_value = torch.sum(exp_minus_c_t * f2_chunk, dim=0) * dt  # same shape as u_chunk
        # No blanket .squeeze(): it would also drop genuine size-1 grid axes.
        f_u = (exp_c_u / a(u_chunk)) * integral_value

        result.append(torch.cat((result_tensor[start:end], f_u.unsqueeze(-1)), dim=-1))

    return torch.cat(result, dim=0)
    



def f3_u_plus(result_tensor, M, N, f2_z_values_expanded_1, num_intervals=200, chunk_size=5):
    """Append f_u = exp(c(u)) / a(u) * integral from theta to u of exp(-c(t)) * f2 dt.

    ``result_tensor`` carries (x, y, theta, u) in its last axis. f2 is looked
    up once per cell, at the node of z = linspace(-M, N, num_intervals)
    nearest to the lower limit t = theta (first index on ties), and held
    constant over the integral. The integral uses ``num_intervals`` equally
    spaced samples from theta to u (ends included), each weighted by
    (u - theta) / num_intervals. The first axis is processed ``chunk_size``
    rows at a time.

    Args:
        result_tensor: (x, y, theta, u) in the last axis; the lookup assumes
            three leading axes, presumably (n0, n0, P) — TODO confirm.
        M, N: bounds of the z lookup grid.
        f2_z_values_expanded_1: only index [0] is read. It is gathered as
            [i, j, z_index], so it must cover the first two axes of
            ``result_tensor`` and at least ``num_intervals`` z entries.
        num_intervals: number of t samples and of z lookup nodes.
        chunk_size: first-axis rows per iteration.

    Returns:
        ``result_tensor`` with f_u concatenated as a new last-axis entry.
    """
    theta = result_tensor[..., 2]
    u = result_tensor[..., 3]
    device = result_tensor.device

    z = torch.linspace(-M, N, num_intervals).to(device)

    # Fractions 0..1 on a new leading axis, broadcastable over u's shape.
    frac = torch.linspace(0, 1, num_intervals, device=device).view(-1, *([1] * u.dim()))
    theta_lead = theta.unsqueeze(0)
    t_values = theta_lead + (u.unsqueeze(0) - theta_lead) * frac  # (num_intervals, *u.shape)

    # Vectorised nearest-node lookup (replaces a per-element Python loop).
    # argmin returns the first minimal index, matching the scalar version.
    t_start = t_values[0]
    nearest_idx = torch.argmin(torch.abs(t_start.unsqueeze(-1) - z), dim=-1)
    f2_slice = f2_z_values_expanded_1[0]
    tensorf2_0 = torch.gather(f2_slice, -1, nearest_idx).to(device=t_start.device, dtype=t_start.dtype)
    # Same f2 value for every t sample; expand is a view (no copy needed).
    tensorf2 = tensorf2_0.unsqueeze(0).expand(num_intervals, *tensorf2_0.shape)

    result = []
    for start in range(0, u.shape[0], chunk_size):
        end = min(start + chunk_size, u.shape[0])

        theta_chunk = theta[start:end]
        u_chunk = u[start:end]
        f2_chunk = tensorf2[:, start:end]
        t_chunk = t_values[:, start:end]

        dt = (u_chunk - theta_chunk) / num_intervals
        exp_c_u = torch.exp(c(u_chunk))
        exp_minus_c_t = torch.exp(-c(t_chunk))

        integral_value = torch.sum(exp_minus_c_t * f2_chunk, dim=0) * dt  # same shape as u_chunk
        # No blanket .squeeze(): it would also drop genuine size-1 grid axes.
        f_u = (exp_c_u / a(u_chunk)) * integral_value

        result.append(torch.cat((result_tensor[start:end], f_u.unsqueeze(-1)), dim=-1))

    return torch.cat(result, dim=0)
    
    

##################################################################################################################################################################
def create_combined_tensor0(x_values, y_values, min_theta_values):
    """Stack an 'ij' (x, y) mesh with a per-cell theta into one tensor.

    Returns:
        (combined, x_grid, y_grid) where combined[i, j] = (x_i, y_j, theta_ij).
    """
    mesh = torch.meshgrid(x_values, y_values, indexing='ij')
    x_grid = mesh[0]
    y_grid = mesh[1]
    combined = torch.stack([x_grid, y_grid, min_theta_values], dim=-1)
    return combined, x_grid, y_grid


def create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2, f2_type='minus'):
    """Attach a per-cell u grid of ``numpoints_f2`` nodes to every (x, y, theta) cell.

    'minus' maps linspace(-M, 1) affinely onto [-M, theta]. Any other
    ``f2_type`` maps linspace(-M, N) affinely onto [theta, N].

    Args:
        combined_tensor0: (X, Y, 3) tensor of (x, y, theta).
        M, N: interval bounds described above.
        numpoints_f2: nodes per u grid.
        f2_type: 'minus' or anything else (treated as 'plus').

    Returns:
        (combined, step): combined is (X, Y, numpoints_f2, 4) holding
        (x, y, theta, u); step is (X, Y, numpoints_f2) holding the node spacing.
    """
    dev = combined_tensor0.device
    theta_col = combined_tensor0[..., 2].unsqueeze(-1)              # (X, Y, 1)
    theta_full = theta_col.expand(-1, -1, numpoints_f2)             # (X, Y, P)

    if f2_type != 'minus':
        base = torch.linspace(-M, N, numpoints_f2).to(dev)
        u_grid = theta_col + (base - (-M)) * (N - theta_col) / (N - (-M))
        step = (N - theta_full) / (numpoints_f2 - 1)
    else:
        base = torch.linspace(-M, 1, numpoints_f2).to(dev)
        u_grid = -M + (base - (-M)) * (theta_full - (-M)) / (1 - (-M))
        step = (theta_full - (-M)) / (numpoints_f2 - 1)

    def _spread(component):
        # Repeat one (X, Y) component along the new u axis.
        return combined_tensor0[..., component].unsqueeze(-1).expand(-1, -1, numpoints_f2)

    stacked = torch.stack((_spread(0), _spread(1), theta_full, u_grid), dim=-1)
    return stacked, step

def separate_tensor_by_z_theta(x_grid, y_grid, min_theta_values, z_values, n0, n2):
    """Build the (x, y, theta, z) grid and split it by z <= theta.

    Args:
        x_grid, y_grid, min_theta_values: (n0, n0) tensors.
        z_values: (n2,) tensor of z nodes.
        n0, n2: grid sizes used for the expansions.

    Returns:
        (grid_1, grid_2), each of shape (n0, n0, n2, 4). grid_1 keeps the cells
        with z <= theta and sets all four components of the others to +inf.
        grid_2 is the complement.
    """
    def _lift(plane):
        return plane.unsqueeze(-1).expand(n0, n0, n2)

    z_full = z_values.unsqueeze(0).unsqueeze(0).expand(n0, n0, n2)
    base = torch.stack((_lift(x_grid), _lift(y_grid), _lift(min_theta_values), z_full), dim=-1)
    below = base[..., 3] <= base[..., 2]

    inf = float('inf')
    grid_1 = base.clone()
    grid_1[~below] = inf
    grid_2 = base.clone()
    grid_2[below] = inf
    return grid_1, grid_2

def calculate_f_u(f_u_func, combined_tensor, M, N, f2_z_values_expanded, step_size):
    """Call ``f_u_func`` and append ``step_size`` as one more last-axis entry.

    Args:
        f_u_func: callable(combined_tensor, M, N, f2_z_values_expanded) -> tensor.
        step_size: tensor matching the result's leading shape.

    Returns:
        The result of ``f_u_func`` with ``step_size`` concatenated on the last axis.
    """
    extended = f_u_func(combined_tensor, M, N, f2_z_values_expanded)
    return torch.cat([extended, step_size[..., None]], dim=-1)

def compute_result_grid(grid, f_u, compare_func):
    """For every grid cell, sum f_u * step over the matching f_u rows.

    A row of ``f_u`` matches a cell when its (x, y, theta) equal the cell's
    exactly (float equality) and ``compare_func(u, z)`` is true.

    Args:
        grid: (n0, n1, n2, >=4) tensor with (x, y, theta, z) in the first four
            last-axis slots. Cells filled with inf match nothing and yield 0.
        f_u: tensor whose last axis holds (x, y, theta, u, f_u, step).
        compare_func: callable(u_values, z) -> boolean mask.

    Returns:
        (n0, n1, n2) tensor on ``grid.device``.
    """
    n0, n1, n2 = grid.shape[:3]
    out = torch.zeros(n0, n1, n2).to(grid.device)

    fx, fy, ftheta = f_u[..., 0], f_u[..., 1], f_u[..., 2]

    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                gx, gy, gtheta, gz = grid[i, j, k, :4]
                rows = f_u[(fx == gx) & (fy == gy) & (ftheta == gtheta)]
                picked = rows[compare_func(rows[..., 3], gz)]
                out[i, j, k] = torch.sum(picked[..., 4] * picked[..., 5])

    return out

def calculate_f3_minus_z(f2_z_values_expanded, combined_tensor2, M, N, grid_1, step_size1):
    """Compute the z <= theta half of f3 and append it to ``grid_1``.

    f_u rows come from ``f3_u_minus`` plus the step size. For each cell of
    ``grid_1`` the rows with u < z are summed by ``compute_result_grid``.

    Returns:
        ``grid_1`` with the per-cell sum concatenated on the last axis.
    """
    f_u_rows = calculate_f_u(f3_u_minus, combined_tensor2, M, N, f2_z_values_expanded, step_size1)

    def _u_below_z(u, z):
        return u < z

    totals = compute_result_grid(grid_1, f_u_rows, _u_below_z)
    return torch.cat([grid_1, totals[..., None]], dim=-1)

def calculate_f3_plus_z(f2_z_values_expanded, combined_tensor4, M, N, grid_2, step_size2):
    """Compute the z > theta half of f3 and append it to ``grid_2``.

    f_u rows come from ``f3_u_plus`` plus the step size. For each cell of
    ``grid_2`` the rows with u > z are summed by ``compute_result_grid``.

    Returns:
        ``grid_2`` with the per-cell sum concatenated on the last axis.
    """
    f_u_rows = calculate_f_u(f3_u_plus, combined_tensor4, M, N, f2_z_values_expanded, step_size2)

    def _u_above_z(u, z):
        return u > z

    totals = compute_result_grid(grid_2, f_u_rows, _u_above_z)
    return torch.cat([grid_2, totals[..., None]], dim=-1)

def get_f3_z(f3_minus_z, f3_plus_z):
    """Merge the two f3 halves, save the result to 'f3_z.pt' and return it.

    Cells whose first component in ``f3_minus_z`` is +/-inf are taken from
    ``f3_plus_z``; all other cells keep their ``f3_minus_z`` values. Neither
    input is modified.
    """
    take_plus = torch.isinf(f3_minus_z[:, :, :, 0]).unsqueeze(-1).expand_as(f3_minus_z)
    merged = f3_minus_z.clone()
    merged[take_plus] = f3_plus_z[take_plus]
    torch.save(merged, 'f3_z.pt')
    return merged

# Main function to call the above
def main_function(n0, n1, n2, combined_tensor2, combined_tensor4, M, N, f2_z_values_expanded_1, grid_1, grid_2, step_size1, step_size2):
    """Compute both halves of f3 and merge them into one tensor.

    The minus half is computed first, then the plus half. ``get_f3_z`` does
    the merge, which also saves 'f3_z.pt'. ``n0``, ``n1`` and ``n2`` are
    accepted but not used here.
    """
    lower_half = calculate_f3_minus_z(f2_z_values_expanded_1, combined_tensor2, M, N, grid_1, step_size1)
    upper_half = calculate_f3_plus_z(f2_z_values_expanded_1, combined_tensor4, M, N, grid_2, step_size2)
    return get_f3_z(lower_half, upper_half)


def calculate_knn(f3_z, f2_z):
    """Return 1 / min over cells with x < y of (max over z of f3 / f2).

    Args:
        f3_z: (n0, n0, n2, >=5) tensor with x, y, z in slots 0, 1, 3 and the
            f3 values in slot 4.
        f2_z: tensor whose slot 4 holds the f2 values, broadcastable
            against f3's.

    Returns:
        0-d tensor. Ratios that are NaN (e.g. 0/0) are treated as -inf so
        they never win the max over z. The x < y test uses the first z entry
        of each cell. Cells with x >= y are excluded (treated as +inf in the
        min).
    """
    x = f3_z[..., 0]  # [n0, n0, n2]
    y = f3_z[..., 1]  # [n0, n0, n2]

    ratio = f3_z[..., 4] / f2_z[..., 4]
    ratio = torch.nan_to_num(ratio, nan=float('-inf'))

    sup_z = torch.max(ratio, dim=-1).values

    x_less_than_y = x[..., 0] < y[..., 0]
    filtered_sup_z = torch.where(x_less_than_y, sup_z, torch.tensor(float('inf'), device=sup_z.device))

    inf_x_less_than_y = torch.min(filtered_sup_z)
    return 1 / inf_x_less_than_y



def calculate_knn_series(KNN, n0, n1, n2, M, N, tensor, x_values, y_values, theta_values, numpoints_f2_minus_z):
    """Iteratively build the next f tensor from f2 and record a knn value per step.

    Loads the starting tensor from 'f2_z.pt' in the current working directory.
    Then for i = 2..KNN it:
      1. evaluates f3_minus / f3_plus on ``tensor`` with the current f values.
         For each (x, y) cell it picks the entry of ``theta_values`` that
         minimises |f_plus - f_minus| over the last axis;
      2. builds the u grids and the z <= theta / z > theta split for those
         thetas;
      3. computes the next f with ``main_function`` (which also saves
         'f3_z.pt');
      4. stores ``calculate_knn(f_next, f_current)`` under key f'knn-{i}' and
         prints it.

    ``tensor`` itself is not updated between iterations; only the f values
    change.

    Returns:
        dict mapping 'knn-i' -> Python float.
    """
    device = tensor.device
    knn_results = {}


    # NOTE(review): torch.load unpickles the file, so only load trusted files
    # (recent PyTorch versions support weights_only=True).
    f2_z = torch.load('f2_z.pt')
    # Presumably the last component holds the f2 values — confirm the file layout.
    f2_z_values = f2_z[..., -1]
    # Copy along a new trailing axis: [..., k, j] = f2_z_values[..., k]
    f_z_values_expanded = f2_z_values.unsqueeze(-1).expand(-1, -1, -1, f2_z_values.shape[-1])
    # n2 copies on a new leading axis (f3_u_minus / f3_u_plus read only index 0)
    f_z_values_expanded_1 = f2_z_values.unsqueeze(0).repeat(n2, 1, 1, 1)

    f_i = f2_z  

    for i in range(2, KNN + 1):

        f_minus_values = f3_minus(tensor, M, N, f_z_values_expanded).to(device)
        f_plus_values = f3_plus(tensor, M, N, f_z_values_expanded).to(device)

        # For each cell choose the theta index where the two sides are closest.
        f_theta = torch.abs(f_plus_values - f_minus_values).to(device)
        min_f_theta_values, min_theta_indices = torch.min(f_theta, dim=-1)
        min_theta_values = theta_values[min_theta_indices].to(device)
        

        combined_tensor0, x_grid, y_grid = create_combined_tensor0(x_values, y_values, min_theta_values)
        # Per-cell u grids on [-M, theta] (minus) and [theta, N] (plus)
        combined_tensor2, step_size1 = create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2_minus_z, f2_type='minus')
        combined_tensor4, step_size2 = create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2_minus_z, f2_type='plus')

        z_values = torch.linspace(-M, N, n2).to(device)
        grid_1, grid_2 = separate_tensor_by_z_theta(x_grid, y_grid, min_theta_values, z_values, n0, n2)

        f_i_plus_1 = main_function(n0, n1, n2, combined_tensor2, combined_tensor4, M, N, f_z_values_expanded_1, grid_1, grid_2, step_size1, step_size2)

        knn_value = calculate_knn(f_i_plus_1, f_i)
        knn_results[f'knn-{i}'] = knn_value.item()  

        # Refresh the f values used by the next iteration from the new tensor.
        f_i = f_i_plus_1
        f_z_values = f_i[...,-1]
        f_z_values_expanded = f_z_values.unsqueeze(-1).expand(-1, -1, -1, f_z_values.shape[-1])
        f_z_values_expanded_1 = f_z_values.unsqueeze(0)
        f_z_values_expanded_1 = f_z_values_expanded_1.repeat(n2, 1, 1, 1)



        print(f'knn-{i}: {knn_value}')

    return knn_results
